{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X4cRE8IbIrIV"
      },
      "source": [
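        "The Jupyter notebook for this article is in the [Chapter 4 code repository](https://github.com/datawhalechina/learn-nlp-with-transformers/tree/main/docs/%E7%AF%87%E7%AB%A04-%E4%BD%BF%E7%94%A8Transformers%E8%A7%A3%E5%86%B3NLP%E4%BB%BB%E5%8A%A1).\n",
        "\n",
        "You can also open this tutorial directly in Google Colab and download the datasets and models there.\n",
        "If you are running this notebook in Google Colab, you may need to install the 🤗 Transformers and 🤗 Datasets libraries. The following command installs them."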
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MOsHUjgdIrIW"
      },
      "outputs": [],
      "source": [
        "!pip install transformers datasets"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HFASsisvIrIb"
      },
      "source": [
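        "If you are running this notebook locally, make sure the dependencies above are installed.\n",
        "You can also find a multi-GPU, distributed-training version of this notebook [here](https://github.com/huggingface/transformers/tree/master/examples/text-classification).\n"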
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rEJBSTyZIrIb"
      },
      "source": [
        "# Fine-tuning a pretrained model for text classification"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kTCFado4IrIc"
      },
      "source": [
        "We will show how to use models from the [🤗 Transformers](https://github.com/huggingface/transformers) library to solve text classification tasks taken from the [GLUE Benchmark](https://gluebenchmark.com/).\n",
        "\n",
        "![Widget inference on a text classification task](https://github.com/huggingface/notebooks/blob/master/examples/images/text_classification.png?raw=1)\n",
        "\n",
        "The GLUE benchmark contains nine classification tasks on sentences or sentence pairs:\n",
        "- [CoLA](https://nyu-mll.github.io/CoLA/) (Corpus of Linguistic Acceptability) Determine whether a sentence is grammatically acceptable.\n",
        "- [MNLI](https://arxiv.org/abs/1704.05426) (Multi-Genre Natural Language Inference) Determine whether a sentence entails, contradicts, or is unrelated to a given hypothesis.\n",
        "- [MRPC](https://www.microsoft.com/en-us/download/details.aspx?id=52398) (Microsoft Research Paraphrase Corpus) Determine whether two sentences are paraphrases of each other.\n",
        "- [QNLI](https://rajpurkar.github.io/SQuAD-explorer/) (Question-answering Natural Language Inference) Determine whether the second sentence contains the answer to the question in the first sentence.\n",
        "- [QQP](https://data.quora.com/First-Quora-Dataset-Release-Question-Pairs) (Quora Question Pairs2) Determine whether two questions are semantically equivalent.\n",
        "- [RTE](https://aclweb.org/aclwiki/Recognizing_Textual_Entailment) (Recognizing Textual Entailment) Determine whether a sentence entails a given hypothesis.\n",
        "- [SST-2](https://nlp.stanford.edu/sentiment/index.html) (Stanford Sentiment Treebank) Determine whether a sentence expresses positive or negative sentiment.\n",
        "- [STS-B](http://ixa2.si.ehu.es/stswiki/index.php/STSbenchmark) (Semantic Textual Similarity Benchmark) Rate the similarity of two sentences with a score from 0 to 5.\n",
        "- [WNLI](https://cs.nyu.edu/faculty/davise/papers/WinogradSchemas/WS.html) (Winograd Natural Language Inference) Determine if a sentence with an anonymous pronoun and a sentence with this pronoun replaced are entailed or not. \n",
        "\n",
        "For these tasks, we will show how to load the data with the 🤗 Datasets library and how to fine-tune a pretrained model with the `Trainer` API from 🤗 Transformers."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "YZbiBDuGIrId"
      },
      "outputs": [],
      "source": [
        "GLUE_TASKS = [\"cola\", \"mnli\", \"mnli-mm\", \"mrpc\", \"qnli\", \"qqp\", \"rte\", \"sst2\", \"stsb\", \"wnli\"]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4RRkXuteIrIh"
      },
      "source": [
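        "In principle, this notebook can be used with a wide range of Transformer models from the [Model Hub](https://huggingface.co/models) to solve any text classification task.\n",
        "\n",
        "If your task is slightly different, small changes are most likely enough to adapt this notebook. You should also adjust the batch size used for fine-tuning to the memory of your GPU to avoid out-of-memory errors."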
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "zVvslsfMIrIh"
      },
      "outputs": [],
      "source": [
        "task = \"cola\"\n",
        "model_checkpoint = \"distilbert-base-uncased\"\n",
        "batch_size = 16"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "whPRbBNbIrIl"
      },
      "source": [
        "## Loading the data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W7QYTpxXIrIl"
      },
      "source": [
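        "We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to load the data and the matching evaluation metric, which only takes a call to `load_dataset` and `load_metric`. (In newer releases of 🤗 Datasets, `load_metric` has moved to the separate 🤗 Evaluate library as `evaluate.load`.)"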
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "IreSlFmlIrIm"
      },
      "outputs": [],
      "source": [
        "from datasets import load_dataset, load_metric"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CKx2zKs5IrIq"
      },
      "source": [
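        "Except for `mnli-mm`, which is this notebook's name for MNLI evaluated on the mismatched validation set, every task can be loaded directly by its name. The data is cached automatically once it has been loaded."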
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s_AY1ATSIrIq"
      },
      "outputs": [],
      "source": [
        "actual_task = \"mnli\" if task == \"mnli-mm\" else task\n",
        "dataset = load_dataset(\"glue\", actual_task)\n",
        "metric = load_metric('glue', actual_task)"
      ]
    },
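    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The name handling above can be sketched as a small helper. This is only an illustration: `resolve_glue_task` is a hypothetical function, not part of 🤗 Datasets, and the split names assume the GLUE layout used further down in this notebook (`validation_matched` and `validation_mismatched` for MNLI, `validation` for every other task)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def resolve_glue_task(task_name):\n",
        "    # Hypothetical helper: map a task name from GLUE_TASKS to the GLUE\n",
        "    # configuration to load and to the split used for validation.\n",
        "    config_name = 'mnli' if task_name == 'mnli-mm' else task_name\n",
        "    if task_name == 'mnli-mm':\n",
        "        validation_split = 'validation_mismatched'\n",
        "    elif task_name == 'mnli':\n",
        "        validation_split = 'validation_matched'\n",
        "    else:\n",
        "        validation_split = 'validation'\n",
        "    return config_name, validation_split\n",
        "\n",
        "print(resolve_glue_task('mnli-mm'))  # ('mnli', 'validation_mismatched')\n",
        "print(resolve_glue_task('cola'))     # ('cola', 'validation')"
      ]
    },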
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RzfPtOMoIrIu"
      },
      "source": [
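        "The `dataset` object is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict). It holds one key each for the training, validation, and test sets (`train`, `validation`, `test`), and indexing with a key returns the corresponding split."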
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "GWiVUF0jIrIv",
        "outputId": "b5d9d856-eaa3-4444-c650-1642a797cb77"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "DatasetDict({\n",
              "    train: Dataset({\n",
              "        features: ['sentence', 'label', 'idx'],\n",
              "        num_rows: 8551\n",
              "    })\n",
              "    validation: Dataset({\n",
              "        features: ['sentence', 'label', 'idx'],\n",
              "        num_rows: 1043\n",
              "    })\n",
              "    test: Dataset({\n",
              "        features: ['sentence', 'label', 'idx'],\n",
              "        num_rows: 1063\n",
              "    })\n",
              "})"
            ]
          },
          "execution_count": 14,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "u3EtYfeHIrIz"
      },
      "source": [
        "To look at an example, select a split by its key (`train`, `validation`, or `test`) and then an index."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "X6HrpprwIrIz",
        "outputId": "1a1cf3a9-3349-40e6-88e2-912be6462daa"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'idx': 0,\n",
              " 'label': 1,\n",
              " 'sentence': \"Our friends won't buy this analysis, let alone the next one we propose.\"}"
            ]
          },
          "execution_count": 15,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "dataset[\"train\"][0]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WHUmphG3IrI3"
      },
      "source": [
        "To get a better feel for what the data looks like, the following function displays a few examples picked at random from the dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "i3j8APAoIrI3"
      },
      "outputs": [],
      "source": [
        "import datasets\n",
        "import random\n",
        "import pandas as pd\n",
        "from IPython.display import display, HTML\n",
        "\n",
        "def show_random_elements(dataset, num_examples=10):\n",
        "    assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n",
        "    picks = []\n",
        "    for _ in range(num_examples):\n",
        "        pick = random.randint(0, len(dataset)-1)\n",
        "        while pick in picks:\n",
        "            pick = random.randint(0, len(dataset)-1)\n",
        "        picks.append(pick)\n",
        "    \n",
        "    df = pd.DataFrame(dataset[picks])\n",
        "    for column, typ in dataset.features.items():\n",
        "        if isinstance(typ, datasets.ClassLabel):\n",
        "            df[column] = df[column].transform(lambda i: typ.names[i])\n",
        "    display(HTML(df.to_html()))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 359
        },
        "id": "SZy5tRB_IrI7",
        "outputId": "bf2a2b5b-0bda-41d0-edd4-db568b0cfcc8"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>sentence</th>\n",
              "      <th>label</th>\n",
              "      <th>idx</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>The more I talk to Joe, the less about linguistics I am inclined to think Sally has taught him to appreciate.</td>\n",
              "      <td>acceptable</td>\n",
              "      <td>196</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>Have in our class the kids arrived safely?</td>\n",
              "      <td>unacceptable</td>\n",
              "      <td>3748</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>I gave Mary a book.</td>\n",
              "      <td>acceptable</td>\n",
              "      <td>5302</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>Every student, who attended the party, had a good time.</td>\n",
              "      <td>unacceptable</td>\n",
              "      <td>4944</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>Bill pounded the metal fiat.</td>\n",
              "      <td>acceptable</td>\n",
              "      <td>2178</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>5</th>\n",
              "      <td>It bit me on the leg.</td>\n",
              "      <td>acceptable</td>\n",
              "      <td>5908</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>6</th>\n",
              "      <td>The boys were made a good mother by Aunt Mary.</td>\n",
              "      <td>unacceptable</td>\n",
              "      <td>736</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7</th>\n",
              "      <td>More of a man is here.</td>\n",
              "      <td>unacceptable</td>\n",
              "      <td>5403</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8</th>\n",
              "      <td>My mother baked me a birthday cake.</td>\n",
              "      <td>acceptable</td>\n",
              "      <td>3761</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>9</th>\n",
              "      <td>Gregory appears to have wanted to be loyal to the company.</td>\n",
              "      <td>acceptable</td>\n",
              "      <td>4334</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "output_type": "display_data"
        }
      ],
      "source": [
        "show_random_elements(dataset[\"train\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lnjDIuQ3IrI-"
      },
      "source": [
        "The metric is an instance of [`datasets.Metric`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Metric):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5o4rUteaIrI_",
        "outputId": "064cdacb-f8e3-432b-d504-715b59330275"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "Metric(name: \"glue\", features: {'predictions': Value(dtype='int64', id=None), 'references': Value(dtype='int64', id=None)}, usage: \"\"\"\n",
              "Compute GLUE evaluation metric associated to each GLUE dataset.\n",
              "Args:\n",
              "    predictions: list of predictions to score.\n",
              "        Each translation should be tokenized into a list of tokens.\n",
              "    references: list of lists of references for each translation.\n",
              "        Each reference should be tokenized into a list of tokens.\n",
              "Returns: depending on the GLUE subset, one or several of:\n",
              "    \"accuracy\": Accuracy\n",
              "    \"f1\": F1 score\n",
              "    \"pearson\": Pearson Correlation\n",
              "    \"spearmanr\": Spearman Correlation\n",
              "    \"matthews_correlation\": Matthew Correlation\n",
              "Examples:\n",
              "\n",
              "    >>> glue_metric = datasets.load_metric('glue', 'sst2')  # 'sst2' or any of [\"mnli\", \"mnli_mismatched\", \"mnli_matched\", \"qnli\", \"rte\", \"wnli\", \"hans\"]\n",
              "    >>> references = [0, 1]\n",
              "    >>> predictions = [0, 1]\n",
              "    >>> results = glue_metric.compute(predictions=predictions, references=references)\n",
              "    >>> print(results)\n",
              "    {'accuracy': 1.0}\n",
              "\n",
              "    >>> glue_metric = datasets.load_metric('glue', 'mrpc')  # 'mrpc' or 'qqp'\n",
              "    >>> references = [0, 1]\n",
              "    >>> predictions = [0, 1]\n",
              "    >>> results = glue_metric.compute(predictions=predictions, references=references)\n",
              "    >>> print(results)\n",
              "    {'accuracy': 1.0, 'f1': 1.0}\n",
              "\n",
              "    >>> glue_metric = datasets.load_metric('glue', 'stsb')\n",
              "    >>> references = [0., 1., 2., 3., 4., 5.]\n",
              "    >>> predictions = [0., 1., 2., 3., 4., 5.]\n",
              "    >>> results = glue_metric.compute(predictions=predictions, references=references)\n",
              "    >>> print({\"pearson\": round(results[\"pearson\"], 2), \"spearmanr\": round(results[\"spearmanr\"], 2)})\n",
              "    {'pearson': 1.0, 'spearmanr': 1.0}\n",
              "\n",
              "    >>> glue_metric = datasets.load_metric('glue', 'cola')\n",
              "    >>> references = [0, 1]\n",
              "    >>> predictions = [0, 1]\n",
              "    >>> results = glue_metric.compute(predictions=predictions, references=references)\n",
              "    >>> print(results)\n",
              "    {'matthews_correlation': 1.0}\n",
              "\"\"\", stored examples: 0)"
            ]
          },
          "execution_count": 18,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "metric"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jAWdqcUBIrJC"
      },
      "source": [
        "Call the metric's `compute` method with your `predictions` and the labels (passed as `references`) to get the metric value:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6XN1Rq0aIrJC",
        "outputId": "0fa6d397-4d58-40ce-d14c-e527de32074d"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'matthews_correlation': 0.1513518081969605}"
            ]
          },
          "execution_count": 19,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "import numpy as np\n",
        "\n",
        "fake_preds = np.random.randint(0, 2, size=(64,))\n",
        "fake_labels = np.random.randint(0, 2, size=(64,))\n",
        "metric.compute(predictions=fake_preds, references=fake_labels)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YOCrQwPoIrJG"
      },
      "source": [
        "Each task is evaluated with a different metric:\n",
        "\n",
        "- for CoLA: [Matthews Correlation Coefficient](https://en.wikipedia.org/wiki/Matthews_correlation_coefficient)\n",
        "- for MNLI (matched or mismatched): Accuracy\n",
        "- for MRPC: Accuracy and [F1 score](https://en.wikipedia.org/wiki/F1_score)\n",
        "- for QNLI: Accuracy\n",
        "- for QQP: Accuracy and [F1 score](https://en.wikipedia.org/wiki/F1_score)\n",
        "- for RTE: Accuracy\n",
        "- for SST-2: Accuracy\n",
        "- for STS-B: [Pearson Correlation Coefficient](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient) and [Spearman's Rank Correlation Coefficient](https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient)\n",
        "- for WNLI: Accuracy\n",
        "\n",
        "So always make sure the metric matches the task."
      ]
    },
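    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the CoLA metric less of a black box, here is a minimal NumPy sketch of the Matthews correlation coefficient for binary labels. It is an illustration under stated assumptions, not the code the `glue` metric runs: `matthews_corrcoef_binary` is a hypothetical name, labels are assumed to be 0 or 1, and the score is defined as 0 when the denominator vanishes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def matthews_corrcoef_binary(predictions, references):\n",
        "    # Minimal sketch of the Matthews correlation coefficient for 0/1 labels.\n",
        "    predictions = np.asarray(predictions)\n",
        "    references = np.asarray(references)\n",
        "    # Confusion-matrix counts, converted to Python ints to avoid overflow.\n",
        "    tp = int(np.sum((predictions == 1) & (references == 1)))\n",
        "    tn = int(np.sum((predictions == 0) & (references == 0)))\n",
        "    fp = int(np.sum((predictions == 1) & (references == 0)))\n",
        "    fn = int(np.sum((predictions == 0) & (references == 1)))\n",
        "    denominator = ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) ** 0.5\n",
        "    # Assumed convention: return 0 when any marginal count is zero.\n",
        "    if denominator == 0:\n",
        "        return 0.0\n",
        "    return (tp * tn - fp * fn) / denominator\n",
        "\n",
        "print(matthews_corrcoef_binary([0, 1, 1, 0], [0, 1, 1, 0]))  # 1.0\n",
        "print(matthews_corrcoef_binary([1, 1, 0, 0], [1, 0, 1, 0]))  # 0.0"
      ]
    },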
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n9qywopnIrJH"
      },
      "source": [
        "## Preprocessing the data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YVx71GdAIrJH"
      },
      "source": [
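        "Before we can feed the data to the model, we need to preprocess it. This is done by a `Tokenizer`, which first tokenizes the input, then converts the tokens to their IDs in the pretrained vocabulary, and finally puts everything into the input format the model expects.\n",
        "\n",
        "To do this, we instantiate the tokenizer with `AutoTokenizer.from_pretrained`, which ensures that:\n",
        "\n",
        "- we get the tokenizer that matches the pretrained model;\n",
        "- we download the token vocabulary that belongs to the chosen model checkpoint.\n",
        "\n",
        "The downloaded vocabulary is cached, so it is not downloaded again the next time it is needed."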
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eXNLu_-nIrJI"
      },
      "outputs": [],
      "source": [
        "from transformers import AutoTokenizer\n",
        "    \n",
        "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Vl6IidfdIrJK"
      },
      "source": [
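        "Note: `use_fast=True` asks for a fast tokenizer, that is, a `transformers.PreTrainedTokenizerFast`, because preprocessing benefits from its special features (such as fast, multi-threaded tokenization of batches). If the model has no fast tokenizer and the call fails, simply remove this option.\n",
        "\n",
        "Almost every model has a fast tokenizer. You can check which features each pretrained model's tokenizer supports in the [model and tokenizer table](https://huggingface.co/transformers/index.html#bigtable)."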
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rowT4iCLIrJK"
      },
      "source": [
        "The tokenizer can preprocess either a single text or a pair of texts, and its output is already in the input format the pretrained model expects."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 21,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "a5hBlsrHIrJL",
        "outputId": "cf52c5d7-2273-453f-e78c-0f0e16fb8e0e"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}"
            ]
          },
          "execution_count": 21,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "tokenizer(\"Hello, this one sentence!\", \"And this sentence goes with it.\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qo_0B1M2IrJM"
      },
      "source": [
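        "Depending on the pretrained model we chose, the tokenizer returns different fields, because each tokenizer is tied to its pretrained model. You can learn more [here](https://huggingface.co/transformers/preprocessing.html).\n",
        "\n",
        "To preprocess the dataset, we need to know which columns hold the sentences for each task, so we define the following dict.\n"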
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "metadata": {
        "id": "fyGdtK9oIrJM"
      },
      "outputs": [],
      "source": [
        "task_to_keys = {\n",
        "    \"cola\": (\"sentence\", None),\n",
        "    \"mnli\": (\"premise\", \"hypothesis\"),\n",
        "    \"mnli-mm\": (\"premise\", \"hypothesis\"),\n",
        "    \"mrpc\": (\"sentence1\", \"sentence2\"),\n",
        "    \"qnli\": (\"question\", \"sentence\"),\n",
        "    \"qqp\": (\"question1\", \"question2\"),\n",
        "    \"rte\": (\"sentence1\", \"sentence2\"),\n",
        "    \"sst2\": (\"sentence\", None),\n",
        "    \"stsb\": (\"sentence1\", \"sentence2\"),\n",
        "    \"wnli\": (\"sentence1\", \"sentence2\"),\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xbqtC4MrIrJO"
      },
      "source": [
        "Check the data format:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "19GG646uIrJO",
        "outputId": "7091c67f-253d-486d-8407-90ac209cf49e"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Sentence: Our friends won't buy this analysis, let alone the next one we propose.\n"
          ]
        }
      ],
      "source": [
        "sentence1_key, sentence2_key = task_to_keys[task]\n",
        "if sentence2_key is None:\n",
        "    print(f\"Sentence: {dataset['train'][0][sentence1_key]}\")\n",
        "else:\n",
        "    print(f\"Sentence 1: {dataset['train'][0][sentence1_key]}\")\n",
        "    print(f\"Sentence 2: {dataset['train'][0][sentence2_key]}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2C0hcmp9IrJQ"
      },
      "source": [
        "Next, we put the preprocessing code into a function:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 24,
      "metadata": {
        "id": "vc0BSBLIIrJQ"
      },
      "outputs": [],
      "source": [
        "def preprocess_function(examples):\n",
        "    if sentence2_key is None:\n",
        "        return tokenizer(examples[sentence1_key], truncation=True)\n",
        "    return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0lm8ozrJIrJR"
      },
      "source": [
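        "The preprocessing function works on a single example as well as on several examples. When it receives several examples, each field of the result is a list with one entry per example:"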
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 25,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-b70jh26IrJS",
        "outputId": "cea1c3e7-2633-445b-8aee-58ca8d9bcf66"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'input_ids': [[101, 2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012, 102], [101, 2028, 2062, 18404, 2236, 3989, 1998, 1045, 1005, 1049, 3228, 2039, 1012, 102], [101, 2028, 2062, 18404, 2236, 3989, 2030, 1045, 1005, 1049, 3228, 2039, 1012, 102], [101, 1996, 2062, 2057, 2817, 16025, 1010, 1996, 13675, 16103, 2121, 2027, 2131, 1012, 102], [101, 2154, 2011, 2154, 1996, 8866, 2024, 2893, 14163, 8024, 3771, 1012, 102]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}"
            ]
          },
          "execution_count": 25,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "preprocess_function(dataset['train'][:5])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zS-6iXTkIrJT"
      },
      "source": [
        "Next, we preprocess every example in `dataset` by using the `map` method, which applies `preprocess_function` to all examples in all splits."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DDtsaJeVIrJT"
      },
      "outputs": [],
      "source": [
        "encoded_dataset = dataset.map(preprocess_function, batched=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "voWiw8C7IrJV"
      },
      "source": [
        "\n",
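        "Even better, the results are cached automatically, so they are not recomputed the next time the notebook runs (but beware that a stale cache can hide changes to the input). 🤗 Datasets checks whether the arguments passed to `map` have changed: if they have not, it reuses the cached data; otherwise it processes the data again. If the arguments are unchanged but you still want fresh results, pass `load_from_cache_file=False` to `map` to bypass the cache. Finally, `batched=True` makes `map` pass several examples to the function at once, which lets the fast tokenizer process the texts of a batch in parallel using multiple threads."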
      ]
    },
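    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "With `batched=True`, the function passed to `map` receives a batch in which each column name maps to a list of values, and it returns new columns in the same shape. The sketch below mimics that contract on a plain dict so it runs without downloading anything. `toy_tokenize_batch` is a hypothetical stand-in that only splits on whitespace; it is not the real tokenizer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def toy_tokenize_batch(batch):\n",
        "    # Hypothetical stand-in for the real tokenizer: split on whitespace.\n",
        "    # `batch` maps each column name to a list with one value per example.\n",
        "    tokens = [sentence.split() for sentence in batch['sentence']]\n",
        "    # Return new columns, again as lists with one entry per example.\n",
        "    return {'tokens': tokens, 'num_tokens': [len(t) for t in tokens]}\n",
        "\n",
        "toy_batch = {\n",
        "    'sentence': ['I gave Mary a book.', 'It bit me on the leg.'],\n",
        "    'label': [1, 1],\n",
        "}\n",
        "toy_output = toy_tokenize_batch(toy_batch)\n",
        "print(toy_output['num_tokens'])  # [5, 6]"
      ]
    },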
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "545PP3o8IrJV"
      },
      "source": [
        "## Fine-tuning the pretrained model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FBiW8UpKIrJW"
      },
      "source": [
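        "Now that the data is ready, we can download and load the pretrained model and fine-tune it. Since this is a sequence classification task, we need a model class built for it, so we use `AutoModelForSequenceClassification`. As with the tokenizer, `from_pretrained` downloads and loads the model and caches it, so it is not downloaded again.\n",
        "\n",
        "Note that the number of labels depends on the task: STS-B is a regression problem with a single output, MNLI has three classes, and every other task has two:\n"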
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 27,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 154,
          "referenced_widgets": [
            "fc7dcd466978430fa564001d1fa8959d",
            "477b25083c3541efb4339e0ceef4e522",
            "8ead287f55c1422fab004d24cbc23d62",
            "ac5fb6c061aa432b82f1b8a499ece40a",
            "cae8ea12d2c943d4ad2f1c4812237916",
            "b778869a30d64bd5aa628163a83b8c45",
            "9b7051f36fe1427ab13a8effbe89d2e3",
            "bd75bd9e0afe4759aa2f747f8562a55c",
            "d7f5cbfcaa3e41f783b24dee39d49e6a",
            "d90bc3f2b14746d18636fe3dbe490c04",
            "79b6d1ffa7b6423fab3ca70b116a75c2"
          ]
        },
        "id": "TlqNaB8jIrJW",
        "outputId": "4d048460-7bb0-44a7-dca6-75ce19bebe5c"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "fc7dcd466978430fa564001d1fa8959d",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading:   0%|          | 0.00/268M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_transform.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_layer_norm.weight']\n",
            "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
            "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
            "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'pre_classifier.bias', 'classifier.bias']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        }
      ],
      "source": [
        "from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer\n",
        "\n",
        "num_labels = 3 if task.startswith(\"mnli\") else 1 if task==\"stsb\" else 2\n",
        "model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CczA5lJlIrJX"
      },
      "source": [
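        "Because we are fine-tuning for text classification but loaded a pretrained language model, the warning tells us that some weights that do not fit were thrown away (the head used for language-model pretraining) and that others were randomly initialized (the new text classification head)."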
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_N8urzhyIrJY"
      },
      "source": [
        "To create a `Trainer`, we need three more things. The most important is [`TrainingArguments`](https://huggingface.co/transformers/main_classes/trainer.html#transformers.TrainingArguments), which contains all the attributes that define the training run.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 28,
      "metadata": {
        "id": "Bliy8zgjIrJY"
      },
      "outputs": [],
      "source": [
        "metric_name = \"pearson\" if task == \"stsb\" else \"matthews_correlation\" if task == \"cola\" else \"accuracy\"\n",
        "\n",
        "args = TrainingArguments(\n",
        "    \"test-glue\",\n",
        "    evaluation_strategy = \"epoch\",\n",
        "    save_strategy = \"epoch\",\n",
        "    learning_rate=2e-5,\n",
        "    per_device_train_batch_size=batch_size,\n",
        "    per_device_eval_batch_size=batch_size,\n",
        "    num_train_epochs=5,\n",
        "    weight_decay=0.01,\n",
        "    load_best_model_at_end=True,\n",
        "    metric_for_best_model=metric_name,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "km3pGVdTIrJc"
      },
      "source": [
        "The `evaluation_strategy = \"epoch\"` argument tells the `Trainer` to run an evaluation at the end of every epoch.\n",
        "\n",
        "`batch_size` was defined near the top of this notebook."
      ]
    },
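    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "These settings also determine how many optimizer steps a run takes. The sketch below works through the arithmetic for CoLA. `total_optimization_steps` is a hypothetical helper, and it assumes a single device, no gradient accumulation, and that the last, smaller batch of each epoch is kept."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "def total_optimization_steps(num_examples, per_device_batch_size, num_epochs):\n",
        "    # Sketch assuming one device, no gradient accumulation, and that the\n",
        "    # last, smaller batch of each epoch is kept.\n",
        "    steps_per_epoch = math.ceil(num_examples / per_device_batch_size)\n",
        "    return steps_per_epoch * num_epochs\n",
        "\n",
        "# CoLA has 8551 training examples; batch size 16 for 5 epochs.\n",
        "print(total_optimization_steps(8551, 16, 5))  # 2675"
      ]
    },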
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7sZOdRlRIrJd"
      },
      "source": [
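        "Finally, because each task needs a different metric, we define a function that turns the model outputs into predictions according to the task and then computes the metric:"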
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 29,
      "metadata": {
        "id": "UmvbnJ9JIrJd"
      },
      "outputs": [],
      "source": [
        "def compute_metrics(eval_pred):\n",
        "    predictions, labels = eval_pred\n",
        "    if task != \"stsb\":\n",
        "        predictions = np.argmax(predictions, axis=1)\n",
        "    else:\n",
        "        predictions = predictions[:, 0]\n",
        "    return metric.compute(predictions=predictions, references=labels)"
      ]
    },
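    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what the two branches of `compute_metrics` do, the sketch below applies the same conversion to hand-made model outputs. `logits_to_predictions` is a hypothetical helper written for this illustration, and the numbers are made up rather than produced by a model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def logits_to_predictions(logits, is_regression=False):\n",
        "    # Hypothetical helper mirroring the branching in compute_metrics.\n",
        "    logits = np.asarray(logits)\n",
        "    if is_regression:\n",
        "        # Regression (STS-B): one output per example, so drop the last axis.\n",
        "        return logits[:, 0]\n",
        "    # Classification: pick the index of the highest score per example.\n",
        "    return np.argmax(logits, axis=1)\n",
        "\n",
        "toy_logits = [[2.0, -1.0], [0.1, 0.3], [-0.5, 1.5]]\n",
        "print(logits_to_predictions(toy_logits))  # [0 1 1]\n",
        "print(logits_to_predictions([[3.2], [0.7]], is_regression=True))  # [3.2 0.7]"
      ]
    },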
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rXuFTAzDIrJe"
      },
      "source": [
        "Pass all of these to the `Trainer`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 30,
      "metadata": {
        "id": "imY1oC3SIrJf"
      },
      "outputs": [],
      "source": [
        "validation_key = \"validation_mismatched\" if task == \"mnli-mm\" else \"validation_matched\" if task == \"mnli\" else \"validation\"\n",
        "trainer = Trainer(\n",
        "    model,\n",
        "    args,\n",
        "    train_dataset=encoded_dataset[\"train\"],\n",
        "    eval_dataset=encoded_dataset[validation_key],\n",
        "    tokenizer=tokenizer,\n",
        "    compute_metrics=compute_metrics\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CdzABDVcIrJg"
      },
      "source": [
        "Start training:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 31,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "uNx5pyRlIrJh",
        "outputId": "bc13d4b6-b797-4522-bd3b-887b4a01e3e4"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "The following columns in the training set  don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence.\n",
            "***** Running training *****\n",
            "  Num examples = 8551\n",
            "  Num Epochs = 5\n",
            "  Instantaneous batch size per device = 16\n",
            "  Total train batch size (w. parallel, distributed & accumulation) = 16\n",
            "  Gradient Accumulation steps = 1\n",
            "  Total optimization steps = 2675\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='2675' max='2675' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [2675/2675 02:49, Epoch 5/5]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>Matthews Correlation</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.525400</td>\n",
              "      <td>0.520955</td>\n",
              "      <td>0.409248</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.351600</td>\n",
              "      <td>0.570341</td>\n",
              "      <td>0.477499</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.236100</td>\n",
              "      <td>0.622785</td>\n",
              "      <td>0.499872</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>0.166300</td>\n",
              "      <td>0.806475</td>\n",
              "      <td>0.491623</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>5</td>\n",
              "      <td>0.125700</td>\n",
              "      <td>0.882225</td>\n",
              "      <td>0.513900</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence.\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 1043\n",
            "  Batch size = 16\n",
            "Saving model checkpoint to test-glue/checkpoint-535\n",
            "Configuration saved in test-glue/checkpoint-535/config.json\n",
            "Model weights saved in test-glue/checkpoint-535/pytorch_model.bin\n",
            "tokenizer config file saved in test-glue/checkpoint-535/tokenizer_config.json\n",
            "Special tokens file saved in test-glue/checkpoint-535/special_tokens_map.json\n",
            "The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence.\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 1043\n",
            "  Batch size = 16\n",
            "Saving model checkpoint to test-glue/checkpoint-1070\n",
            "Configuration saved in test-glue/checkpoint-1070/config.json\n",
            "Model weights saved in test-glue/checkpoint-1070/pytorch_model.bin\n",
            "tokenizer config file saved in test-glue/checkpoint-1070/tokenizer_config.json\n",
            "Special tokens file saved in test-glue/checkpoint-1070/special_tokens_map.json\n",
            "The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence.\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 1043\n",
            "  Batch size = 16\n",
            "Saving model checkpoint to test-glue/checkpoint-1605\n",
            "Configuration saved in test-glue/checkpoint-1605/config.json\n",
            "Model weights saved in test-glue/checkpoint-1605/pytorch_model.bin\n",
            "tokenizer config file saved in test-glue/checkpoint-1605/tokenizer_config.json\n",
            "Special tokens file saved in test-glue/checkpoint-1605/special_tokens_map.json\n",
            "The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence.\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 1043\n",
            "  Batch size = 16\n",
            "Saving model checkpoint to test-glue/checkpoint-2140\n",
            "Configuration saved in test-glue/checkpoint-2140/config.json\n",
            "Model weights saved in test-glue/checkpoint-2140/pytorch_model.bin\n",
            "tokenizer config file saved in test-glue/checkpoint-2140/tokenizer_config.json\n",
            "Special tokens file saved in test-glue/checkpoint-2140/special_tokens_map.json\n",
            "The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence.\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 1043\n",
            "  Batch size = 16\n",
            "Saving model checkpoint to test-glue/checkpoint-2675\n",
            "Configuration saved in test-glue/checkpoint-2675/config.json\n",
            "Model weights saved in test-glue/checkpoint-2675/pytorch_model.bin\n",
            "tokenizer config file saved in test-glue/checkpoint-2675/tokenizer_config.json\n",
            "Special tokens file saved in test-glue/checkpoint-2675/special_tokens_map.json\n",
            "\n",
            "\n",
            "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
            "\n",
            "\n",
            "Loading best model from test-glue/checkpoint-2675 (score: 0.5138995234247261).\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "TrainOutput(global_step=2675, training_loss=0.27181456521292713, metrics={'train_runtime': 169.649, 'train_samples_per_second': 252.02, 'train_steps_per_second': 15.768, 'total_flos': 229537542078168.0, 'train_loss': 0.27181456521292713, 'epoch': 5.0})"
            ]
          },
          "execution_count": 31,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "trainer.train()"
      ]
    },
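    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The step counts in the log above follow directly from the dataset size and the batch size. As a quick check, here is a minimal sketch that assumes a single device, no gradient accumulation, and that the last partial batch is kept. The helper name `steps_per_epoch` is made up for this example.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "def steps_per_epoch(num_examples, batch_size):\n",
        "    # The final batch may be smaller than batch_size, so round up.\n",
        "    return math.ceil(num_examples / batch_size)\n",
        "\n",
        "train_steps = steps_per_epoch(8551, 16)  # CoLA training set, batch size 16\n",
        "total_steps = train_steps * 5            # num_train_epochs=5\n",
        "# One checkpoint is saved at the end of each epoch.\n",
        "checkpoints = [train_steps * epoch for epoch in range(1, 6)]\n",
        "eval_steps = steps_per_epoch(1043, 16)   # CoLA validation set\n",
        "\n",
        "print(train_steps, total_steps)\n",
        "print(checkpoints)\n",
        "print(eval_steps)\n",
        "```\n",
        "\n",
        "The results match `Total optimization steps = 2675` and the `checkpoint-535` through `checkpoint-2675` directories in the log, and each evaluation pass over the 1043 validation examples takes 66 batches."
      ]
    },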
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CKASz-2vIrJi"
      },
      "source": [
        "After training finishes, evaluate the model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 32,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 207
        },
        "id": "UOUcBkX8IrJi",
        "outputId": "29f903c0-9124-4a20-b2e4-159d39a265a2"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence.\n",
            "***** Running Evaluation *****\n",
            "  Num examples = 1043\n",
            "  Batch size = 16\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='66' max='66' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [66/66 00:00]\n",
              "    </div>\n",
              "    "
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "text/plain": [
              "{'epoch': 5.0,\n",
              " 'eval_loss': 0.8822253346443176,\n",
              " 'eval_matthews_correlation': 0.5138995234247261,\n",
              " 'eval_runtime': 0.9319,\n",
              " 'eval_samples_per_second': 1119.255,\n",
              " 'eval_steps_per_second': 70.825}"
            ]
          },
          "execution_count": 32,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "trainer.evaluate()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ffP-VQOyIrJk"
      },
      "source": [
        "To see how your model fared you can compare it to the [GLUE Benchmark leaderboard](https://gluebenchmark.com/leaderboard)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7k8ge1L1IrJk"
      },
      "source": [
        "## Hyperparameter search"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RNfajuw_IrJl"
      },
      "source": [
        "`Trainer` also supports hyperparameter search, using [optuna](https://optuna.org/) or [Ray Tune](https://docs.ray.io/en/latest/tune/) as the backend.\n",
        "\n",
        "Run the two lines below to install them:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YUdakNBhIrJl"
      },
      "outputs": [],
      "source": [
        "! pip install optuna\n",
        "! pip install ray[tune]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ttfT0CqaIrJm"
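      },
      "source": [
        "During hyperparameter search, the `Trainer` performs several training runs, so it needs a function that builds the model instead of a model instance. The `Trainer` calls this function to re-initialize the model at the start of each run:"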
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 34,
      "metadata": {
        "id": "8sgjdLKcIrJm"
      },
      "outputs": [],
      "source": [
        "def model_init():\n",
        "    return AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mMXfVJO4IrJo"
      },
      "source": [
        "Then instantiate the `Trainer` as before, passing `model_init` instead of a model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 35,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "71pt6N0eIrJo",
        "outputId": "8e1a7982-0e3a-491f-bac4-f53114d6a384"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "loading configuration file https://huggingface.co/distilbert-base-uncased/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/23454919702d26495337f3da04d1655c7ee010d5ec9d77bdb9e399e00302c0a1.d423bdf2f58dc8b77d5f5d18028d7ae4a72dcfd8f468e81fe979ada957a8c361\n",
            "Model config DistilBertConfig {\n",
            "  \"activation\": \"gelu\",\n",
            "  \"architectures\": [\n",
            "    \"DistilBertForMaskedLM\"\n",
            "  ],\n",
            "  \"attention_dropout\": 0.1,\n",
            "  \"dim\": 768,\n",
            "  \"dropout\": 0.1,\n",
            "  \"hidden_dim\": 3072,\n",
            "  \"initializer_range\": 0.02,\n",
            "  \"max_position_embeddings\": 512,\n",
            "  \"model_type\": \"distilbert\",\n",
            "  \"n_heads\": 12,\n",
            "  \"n_layers\": 6,\n",
            "  \"pad_token_id\": 0,\n",
            "  \"qa_dropout\": 0.1,\n",
            "  \"seq_classif_dropout\": 0.2,\n",
            "  \"sinusoidal_pos_embds\": false,\n",
            "  \"tie_weights_\": true,\n",
            "  \"transformers_version\": \"4.9.1\",\n",
            "  \"vocab_size\": 30522\n",
            "}\n",
            "\n",
            "loading weights file https://huggingface.co/distilbert-base-uncased/resolve/main/pytorch_model.bin from cache at /root/.cache/huggingface/transformers/9c169103d7e5a73936dd2b627e42851bec0831212b677c637033ee4bce9ab5ee.126183e36667471617ae2f0835fab707baa54b731f991507ebbb55ea85adb12a\n",
            "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_transform.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_layer_norm.weight']\n",
            "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
            "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
            "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'pre_classifier.bias', 'classifier.bias']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        }
      ],
      "source": [
        "trainer = Trainer(\n",
        "    model_init=model_init,\n",
        "    args=args,\n",
        "    train_dataset=encoded_dataset[\"train\"],\n",
        "    eval_dataset=encoded_dataset[validation_key],\n",
        "    tokenizer=tokenizer,\n",
        "    compute_metrics=compute_metrics\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yQxrzFP4IrJq"
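      },
      "source": [
        "Call the `hyperparameter_search` method. Note that this can take a long time, so it can be worth running the search on a subset of the data first and then training on the full dataset with the best hyperparameters.\n",
        "For example, you could search on 1/10 of the training data by passing `encoded_dataset[\"train\"].shard(index=1, num_shards=10)` as `train_dataset`. The call below uses the `train_dataset` that was given to the `Trainer` above:"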
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NboJ7kDOIrJq"
      },
      "outputs": [],
      "source": [
        "best_run = trainer.hyperparameter_search(n_trials=10, direction=\"maximize\")"
      ]
    },
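    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "By default, `hyperparameter_search` samples from a built-in search space. If you want to control the ranges yourself, you can pass your own `hp_space` function. The sketch below assumes the optuna backend, where the function receives an optuna `trial` object. The function name `my_hp_space` and the ranges are illustrative choices, not values taken from this tutorial.\n",
        "\n",
        "```python\n",
        "def my_hp_space(trial):\n",
        "    # Each key should be the name of a TrainingArguments attribute.\n",
        "    return {\n",
        "        # Sample the learning rate on a log scale.\n",
        "        'learning_rate': trial.suggest_float('learning_rate', 1e-5, 5e-5, log=True),\n",
        "        'num_train_epochs': trial.suggest_int('num_train_epochs', 2, 5),\n",
        "        'per_device_train_batch_size': trial.suggest_categorical(\n",
        "            'per_device_train_batch_size', [16, 32]\n",
        "        ),\n",
        "    }\n",
        "\n",
        "# Usage (needs the `trainer` defined above and optuna installed):\n",
        "# best_run = trainer.hyperparameter_search(\n",
        "#     hp_space=my_hp_space, backend='optuna', n_trials=10, direction='maximize'\n",
        "# )\n",
        "```\n",
        "\n",
        "The keys are named after `TrainingArguments` attributes because the sampled values are applied to the training arguments for each trial."
      ]
    },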
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gUTD72qCIrJs"
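      },
      "source": [
        "`hyperparameter_search` returns a `BestRun` object, which holds the value of the objective that was maximized (by default, the sum of all metrics) and the hyperparameters used for that run:"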
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Psi4JymeIrJs"
      },
      "outputs": [],
      "source": [
        "best_run"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dFdjWbRIIrJu"
      },
      "source": [
        "Set the `Trainer`'s arguments to the best hyperparameters found, then train:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EsJ6sqdGIrJu"
      },
      "outputs": [],
      "source": [
        "for n, v in best_run.hyperparameters.items():\n",
        "    setattr(trainer.args, n, v)\n",
        "\n",
        "trainer.train()"
      ]
    },
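    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what the loop above does without running a search, here is a stand-alone sketch with stand-in objects. `FakeBestRun` and `fake_args` are hypothetical substitutes for the real `BestRun` and `trainer.args`, and the hyperparameter values are made up.\n",
        "\n",
        "```python\n",
        "from collections import namedtuple\n",
        "from types import SimpleNamespace\n",
        "\n",
        "# Stand-in mirroring the run_id, objective and hyperparameters fields of BestRun.\n",
        "FakeBestRun = namedtuple('FakeBestRun', ['run_id', 'objective', 'hyperparameters'])\n",
        "\n",
        "fake_best_run = FakeBestRun(\n",
        "    run_id='3',\n",
        "    objective=0.52,\n",
        "    hyperparameters={'learning_rate': 3e-5, 'num_train_epochs': 4},\n",
        ")\n",
        "\n",
        "# Stand-in for trainer.args, starting from the values used earlier in this notebook.\n",
        "fake_args = SimpleNamespace(learning_rate=2e-5, num_train_epochs=5, weight_decay=0.01)\n",
        "\n",
        "# Overwrite only the attributes that were part of the search space.\n",
        "for name, value in fake_best_run.hyperparameters.items():\n",
        "    setattr(fake_args, name, value)\n",
        "\n",
        "print(fake_args)\n",
        "```\n",
        "\n",
        "Only the searched attributes change; anything outside the search space, such as `weight_decay` here, keeps its original value."
      ]
    },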
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "m09tZgZRIrJv"
      },
      "source": [
        "Finally, don't forget to [upload your model](https://huggingface.co/transformers/model_sharing.html) to the [🤗 Model Hub](https://huggingface.co/models). You can then load it by name, just as we loaded the pretrained checkpoint at the start of this notebook."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NjO3bqHz4d1d"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "4.1-文本分类",
      "provenance": []
    },
    "interpreter": {
      "hash": "3bfce0b4c492a35815b5705a19fe374a7eea0baaa08b34d90450caf1fe9ce20b"
    },
    "kernelspec": {
      "display_name": "Python 3.8.10 64-bit ('venv': virtualenv)",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": ""
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "477b25083c3541efb4339e0ceef4e522": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "79b6d1ffa7b6423fab3ca70b116a75c2": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8ead287f55c1422fab004d24cbc23d62": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_9b7051f36fe1427ab13a8effbe89d2e3",
            "placeholder": "​",
            "style": "IPY_MODEL_b778869a30d64bd5aa628163a83b8c45",
            "value": "Downloading: 100%"
          }
        },
        "9b7051f36fe1427ab13a8effbe89d2e3": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ac5fb6c061aa432b82f1b8a499ece40a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d7f5cbfcaa3e41f783b24dee39d49e6a",
            "max": 267967963,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_bd75bd9e0afe4759aa2f747f8562a55c",
            "value": 267967963
          }
        },
        "b778869a30d64bd5aa628163a83b8c45": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "bd75bd9e0afe4759aa2f747f8562a55c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "cae8ea12d2c943d4ad2f1c4812237916": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_79b6d1ffa7b6423fab3ca70b116a75c2",
            "placeholder": "​",
            "style": "IPY_MODEL_d90bc3f2b14746d18636fe3dbe490c04",
            "value": " 268M/268M [00:05&lt;00:00, 46.7MB/s]"
          }
        },
        "d7f5cbfcaa3e41f783b24dee39d49e6a": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d90bc3f2b14746d18636fe3dbe490c04": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "fc7dcd466978430fa564001d1fa8959d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_8ead287f55c1422fab004d24cbc23d62",
              "IPY_MODEL_ac5fb6c061aa432b82f1b8a499ece40a",
              "IPY_MODEL_cae8ea12d2c943d4ad2f1c4812237916"
            ],
            "layout": "IPY_MODEL_477b25083c3541efb4339e0ceef4e522"
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}